So far this semester we have been working with supervised and reinforcement learning algorithms. Another family of machine learning algorithms is unsupervised learning. These algorithms are designed to find patterns or groupings in a data set. No targets, or desired outputs, are involved.
For example, take a look at this data set of eruption durations and the waiting times in between eruptions of the Old Faithful Geyser in Yellowstone National Park.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
!head faithful.csv
"","eruptions","waiting" "1",3.6,79 "2",1.8,54 "3",3.333,74 "4",2.283,62 "5",4.533,85 "6",2.883,55 "7",4.7,88 "8",3.6,85 "9",1.95,51
datadf = pd.read_csv('faithful.csv', usecols=(1, 2))
datadf
| | eruptions | waiting |
|---|---|---|
| 0 | 3.600 | 79 |
| 1 | 1.800 | 54 |
| 2 | 3.333 | 74 |
| 3 | 2.283 | 62 |
| 4 | 4.533 | 85 |
| ... | ... | ... |
| 267 | 4.117 | 81 |
| 268 | 2.150 | 46 |
| 269 | 4.417 | 90 |
| 270 | 1.817 | 46 |
| 271 | 4.467 | 74 |

272 rows × 2 columns
data = datadf.values
data = np.array(data)
type(data), data[:10]
(numpy.ndarray, array([[ 3.6 , 79. ], [ 1.8 , 54. ], [ 3.333, 74. ], [ 2.283, 62. ], [ 4.533, 85. ], [ 2.883, 55. ], [ 4.7 , 88. ], [ 3.6 , 85. ], [ 1.95 , 51. ], [ 4.35 , 85. ]]))
plt.plot(data[:, 0], data[:, 1], '.')
plt.xlabel('duration (minutes)')
plt.ylabel('interval (minutes)')
We can clearly see two clusters here. For higher-dimensional data, we cannot directly visualize the data to see the clusters, so we need a mathematical way to detect them. This gives rise to the class of unsupervised learning methods called clustering algorithms.
A simple example of a clustering algorithm is the k-means algorithm, which finds $k$ cluster centers. It is an iterative algorithm that starts with an initial choice of $k$ centers. It then determines which center each data sample is closest to, which partitions the data into $k$ groups, and moves each center to the mean of the samples in its group. These two steps are then repeated.
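Written as update equations, one iteration looks roughly like this. The notation here is our own: $\mathbf{x}_n$ is the $n^{th}$ sample, $\boldsymbol{\mu}_k$ is the $k^{th}$ center, $c_n$ is the index of the center assigned to sample $n$, and $N_k$ is the number of samples assigned to center $k$.

$$ c_n = \arg\min_{k} ||\mathbf{x}_n - \boldsymbol{\mu}_k||^2, \qquad \boldsymbol{\mu}_k \leftarrow \frac{1}{N_k} \sum_{n \,:\, c_n = k} \mathbf{x}_n $$

The first equation assigns each sample to its nearest center. The second moves each center to the mean of its assigned samples, and it is undefined when $N_k = 0$, that is, when a center attracts no samples.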
Let's develop this algorithm one step at a time.
Each sample in the Old Faithful data has 2 attributes, so each sample is a point in 2-dimensional space. We know by looking at the above plot that our data nicely falls into two clusters, so we will start with $k=2$. We will initialize the two cluster centers by randomly choosing two of the data samples.
n_samples = data.shape[0]
np.random.choice(range(n_samples), 2, replace=False)
array([111, 234])
centers = data[np.random.choice(range(n_samples), 2, replace=False), :]
centers
array([[ 1.867, 48. ], [ 4.9 , 82. ]])
Now we must find all samples that are closest to the first center, and those that are closest to the second center. To do this efficiently, let's first look at how numpy subtracts arrays of different shapes.
a = np.array([1, 2, 3])
b = np.array([10, 20, 30])
a, b
(array([1, 2, 3]), array([10, 20, 30]))
a - b
array([ -9, -18, -27])
But what if we want to subtract every element of `b` from every element of `a`?
np.resize(a, (3, 3))
array([[1, 2, 3], [1, 2, 3], [1, 2, 3]])
np.resize(b, (3, 3))
array([[10, 20, 30], [10, 20, 30], [10, 20, 30]])
np.resize(a, (3, 3)).T
array([[1, 1, 1], [2, 2, 2], [3, 3, 3]])
np.resize(a, (3, 3)).T - np.resize(b, (3, 3))
array([[ -9, -19, -29], [ -8, -18, -28], [ -7, -17, -27]])
However, we can ask numpy to do this duplication for us, a feature called broadcasting, if we reshape `a` into a column vector and leave `b` as a row vector.
a[:, np.newaxis]
array([[1], [2], [3]])
a[:, np.newaxis] - b
array([[ -9, -19, -29], [ -8, -18, -28], [ -7, -17, -27]])
Now imagine that `a` is a cluster center and `b` contains data samples, one per row. The first step of calculating the distance from `a` to all samples in `b` is to subtract them component-wise.
a = np.array([1, 2, 3])
b = np.array([[10, 20, 30], [40, 50, 60]])
print(a)
print(b)
[1 2 3]
[[10 20 30]
 [40 50 60]]
b - a
array([[ 9, 18, 27], [39, 48, 57]])
The single row vector `a` is duplicated for as many rows as there are in `b`! We can use this to calculate the squared distance between a center and every sample.
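Summing the squared component differences is the same as squaring the Euclidean norm of each difference vector. The sketch below checks this on the first three Old Faithful samples, using the center values copied from the run shown next. For that run the results should match the first three entries of `sqdists_to_center_0` printed below. The variable names are illustrative and chosen so they do not overwrite `data` or `centers`.

```python
import numpy as np

center_0 = np.array([1.867, 48.0])               # one center, shape (2,)
first_samples = np.array([[3.6, 79.0],
                          [1.8, 54.0],
                          [3.333, 74.0]])         # three samples, one per row

# Broadcast subtraction, square, then sum across the attribute axis.
toy_sqdists = np.sum((center_0 - first_samples)**2, axis=1)

# The same quantity as the squared Euclidean norm of each row difference.
toy_sqdists_norm = np.linalg.norm(first_samples - center_0, axis=1)**2

print(toy_sqdists.shape)                                # (3,)
print(np.allclose(toy_sqdists, toy_sqdists_norm))       # True
```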
centers[0, :]
array([ 1.867, 48. ])
sqdists_to_center_0 = np.sum((centers[0, :] - data)**2, axis=1)
sqdists_to_center_0
array([9.64003289e+02, 3.60044890e+01, 6.78149156e+02, 1.96173056e+02, 1.37610756e+03, 5.00322560e+01, 1.60802589e+03, 1.37200329e+03, 9.00688900e+00, 1.37516529e+03, 3.60011560e+01, 1.30020250e+03, 9.05442889e+02, 1.01368900e+00, 1.23302589e+03, 1.60900000e+01, 1.96013689e+02, 1.30460249e+03, 1.60712890e+01, 9.66678689e+02, 9.00448900e+00, 1.01368900e+00, 9.02505889e+02, 4.42440000e+02, 6.83107556e+02, 1.22800329e+03, 4.90100000e+01, 7.88910656e+02, 9.03932289e+02, 9.67584356e+02, 6.30919489e+02, 8.47760000e+02, 3.26250000e+02, 1.02869156e+03, 6.79865156e+02, 1.60225000e+01, 0.00000000e+00, 1.03279716e+03, 1.21001156e+02, 1.77250306e+03, 1.03016529e+03, 1.00000256e+02, 1.30329000e+03, 1.00013689e+02, 6.32107556e+02, 1.22710250e+03, 2.59865156e+02, 2.50542890e+01, 1.16365076e+03, 1.21017689e+02, 7.37602489e+02, 1.77211680e+03, 3.60011560e+01, 1.03279716e+03, 3.60179560e+01, 1.23409626e+03, 5.32422500e+02, 2.56040000e+02, 8.48290000e+02, 1.09500250e+03, 1.21133956e+02, 1.30293269e+03, 1.36890000e-02, 1.16460249e+03, 1.44002500e+02, 1.94241609e+03, 9.05290000e+02, 9.08025889e+02, 2.89040000e+02, 6.33025889e+02, 1.16069156e+03, 6.40100000e+01, 9.67932689e+02, 5.33549689e+02, 1.96013456e+02, 7.94240000e+02, 1.44022500e+02, 9.07290000e+02, 7.88064256e+02, 1.22800329e+03, 7.34134756e+02, 1.16208116e+03, 4.88986289e+02, 2.89586756e+02, 6.29840000e+02, 1.60940036e+03, 7.88338889e+02, 1.03102250e+03, 9.00000000e-02, 1.44854969e+03, 1.44110889e+02, 1.77008116e+03, 4.00000000e+00, 9.08702500e+02, 2.25001156e+02, 5.81919489e+02, 1.30384000e+03, 7.32545689e+02, 9.00000000e+00, 1.16519909e+03, 1.96379456e+02, 1.60625000e+03, 1.05428900e+00, 1.23193269e+03, 1.09376549e+03, 1.00000000e+00, 1.30402589e+03, 1.60070560e+01, 1.45289829e+03, 1.09229786e+03, 7.37213956e+02, 1.21187489e+02, 1.69019909e+03, 9.67502500e+02, 1.21027889e+02, 1.09665076e+03, 4.20250000e+00, 1.37646929e+03, 1.21002500e+02, 1.52750250e+03, 2.55625000e+01, 4.45840000e+02, 8.46678689e+02, 6.40100000e+01, 
1.60746929e+03, 1.09261000e+03, 9.00250000e+00, 1.16293269e+03, 4.91600000e+01, 1.77174509e+03, 9.00000000e+00, 1.23029000e+03, 6.48704890e+01, 1.68708116e+03, 4.00115600e+00, 1.16233026e+03, 9.00025600e+00, 1.45340036e+03, 2.50275560e+01, 9.64481956e+02, 1.09459796e+03, 1.44133956e+02, 1.16310756e+03, 8.49702500e+02, 7.90081156e+02, 1.21013456e+02, 1.03165076e+03, 1.02250000e+00, 2.31445229e+03, 2.50044890e+01, 8.51023556e+02, 8.45549689e+02, 2.89284089e+02, 1.09646929e+03, 5.31890000e+02, 4.88549689e+02, 1.09593269e+03, 2.02991066e+03, 2.50044890e+01, 1.68541000e+03, 9.11088900e+00, 1.44921209e+03, 1.00017689e+02, 9.03865156e+02, 3.26666689e+02, 7.91376656e+02, 2.25250000e+02, 1.60981569e+03, 1.60043560e+01, 2.03256250e+03, 1.00250000e+00, 8.10466560e+01, 8.48376656e+02, 4.02149156e+02, 1.09429000e+03, 1.09508116e+03, 6.31932689e+02, 4.30250000e+00, 1.37354969e+03, 6.81290000e+02, 4.90002560e+01, 8.48376656e+02, 1.23067869e+03, 1.22861000e+03, 9.02755600e+00, 9.06584356e+02, 1.30091066e+03, 4.00115600e+00, 1.23150250e+03, 4.90998560e+01, 1.09760249e+03, 8.10011560e+01, 7.92602489e+02, 1.30098629e+03, 8.45405801e+02, 1.09459796e+03, 1.52366669e+03, 8.47245001e+02, 9.14668900e+00, 9.07840000e+02, 1.44054289e+02, 1.16216529e+03, 1.85413476e+03, 2.50000000e+01, 9.07469289e+02, 4.00705600e+00, 8.47250000e+02, 1.29993229e+03, 1.00435600e+00, 1.23193269e+03, 5.29266256e+02, 1.03202589e+03, 1.00000000e+00, 7.32865156e+02, 2.58402500e+02, 7.89597956e+02, 2.52840890e+01, 2.12460249e+03, 4.90176890e+01, 7.89212089e+02, 4.00000000e+00, 1.16176000e+03, 3.60136890e+01, 7.35843456e+02, 9.04549689e+02, 9.66062500e+02, 9.04910656e+02, 9.05760000e+02, 4.88202500e+02, 9.68198489e+02, 4.88910656e+02, 3.63025000e+01, 1.44936386e+03, 4.12250000e+00, 1.77067189e+03, 3.60002560e+01, 3.60002890e+01, 8.46837056e+02, 9.65338889e+02, 2.56217156e+02, 7.34212089e+02, 1.23328900e+00, 1.45340036e+03, 2.26067089e+02, 1.37637666e+03, 1.15986516e+03, 8.10466560e+01, 1.16225000e+03, 3.61070756e+02, 
6.82165289e+02, 3.61108890e+01, 1.23167189e+03, 6.27890000e+02, 6.31932689e+02, 1.60521209e+03, 1.02780250e+03, 5.33202500e+02, 1.23167189e+03, 6.40176890e+01, 9.66837056e+02, 9.08410000e+02, 1.30310756e+03, 1.00000289e+02, 1.23067869e+03, 2.50134560e+01, 1.44146689e+02, 7.37311689e+02, 1.09406250e+03, 4.08008900e+00, 1.77050250e+03, 4.00250000e+00, 6.82760000e+02])
sqdists_to_center_1 = np.sum((centers[1, :] - data)**2, axis=1)
sqdists_to_center_1
array([1.06900000e+01, 7.93610000e+02, 6.64554890e+01, 4.06848689e+02, 9.13468900e+00, 7.33068289e+02, 3.60400000e+01, 1.06900000e+01, 9.69702500e+02, 9.30250000e+00, 7.93406489e+02, 4.96628900e+00, 1.64900000e+01, 1.23492250e+03, 1.04000000e+00, 9.07469289e+02, 4.09922500e+02, 4.01000000e+00, 9.10890000e+02, 9.42250000e+00, 9.70610000e+02, 1.23492250e+03, 1.81025000e+01, 1.72359889e+02, 6.41346890e+01, 2.69000000e+00, 7.37602489e+02, 3.66674890e+01, 1.71025000e+01, 9.21808900e+00, 8.13600000e+01, 2.51874890e+01, 2.58350089e+02, 4.75168900e+00, 6.51384890e+01, 9.08311689e+02, 1.16519909e+03, 4.00448900e+00, 5.38406489e+02, 6.40136890e+01, 4.30250000e+00, 5.85102289e+02, 4.11088900e+00, 5.85922500e+02, 8.11346890e+01, 3.50588900e+00, 3.25138489e+02, 8.48840000e+02, 7.12890000e-02, 5.37410000e+02, 4.90100000e+01, 6.40338560e+01, 7.93406489e+02, 4.00448900e+00, 7.94029889e+02, 1.00028900e+00, 1.22399489e+02, 3.34452289e+02, 2.51108890e+01, 1.33988900e+00, 5.36112889e+02, 4.16000000e+00, 1.16592250e+03, 1.00000000e-02, 4.93504889e+02, 1.00250000e+02, 1.65372890e+01, 1.60400000e+01, 2.97025889e+02, 8.10400000e+01, 7.51689000e-01, 6.84602489e+02, 9.16000000e+00, 1.21810000e+02, 4.08508889e+02, 3.60278890e+01, 4.92311689e+02, 1.61108890e+01, 3.70342890e+01, 2.69000000e+00, 4.95882890e+01, 3.21489000e-01, 1.44640000e+02, 2.94139289e+02, 8.16938890e+01, 3.60010890e+01, 3.69025000e+01, 4.14668900e+00, 1.16346929e+03, 1.68100000e+01, 4.91290000e+02, 6.43214890e+01, 1.03319909e+03, 1.60068890e+01, 3.70406489e+02, 1.00360000e+02, 4.05428900e+00, 5.03225000e+01, 9.70199089e+02, 0.00000000e+00, 4.05841889e+02, 3.62840890e+01, 1.09684000e+03, 1.16000000e+00, 1.72250000e+00, 1.23419909e+03, 4.04000000e+00, 9.09715689e+02, 1.60025000e+01, 2.48108900e+00, 4.90278890e+01, 5.35760000e+02, 4.90000000e+01, 9.23328900e+00, 5.39240000e+02, 1.07128900e+00, 1.03067189e+03, 9.09000000e+00, 5.38504889e+02, 2.52332890e+01, 8.46212089e+02, 1.69693889e+02, 2.54225000e+01, 6.84602489e+02, 
3.60900000e+01, 2.28368900e+00, 1.37789829e+03, 1.60000000e-01, 7.35932689e+02, 6.40625000e+01, 1.37819909e+03, 1.53728900e+00, 6.80410000e+02, 4.93214890e+01, 1.30540649e+03, 2.67289000e-01, 9.70102289e+02, 1.60010890e+01, 8.49219689e+02, 1.03618890e+01, 1.44488900e+00, 4.91112889e+02, 1.34689000e-01, 2.50068890e+01, 3.63214890e+01, 5.37508889e+02, 4.07128900e+00, 1.09731169e+03, 1.96040000e+02, 8.50610000e+02, 2.50176890e+01, 2.58100000e+01, 2.95250000e+02, 1.09000000e+00, 1.22776889e+02, 1.44810000e+02, 1.16000000e+00, 1.21667489e+02, 8.50610000e+02, 4.98704890e+01, 1.37629000e+03, 1.65625000e+01, 5.84410000e+02, 1.71384890e+01, 2.57960000e+02, 3.61004890e+01, 3.67416089e+02, 3.60100000e+01, 9.08803089e+02, 1.21080089e+02, 1.09789829e+03, 6.32935489e+02, 2.51004890e+01, 1.98455489e+02, 1.53728900e+00, 1.32148900e+00, 8.11600000e+01, 1.03016529e+03, 9.81000000e+00, 6.45372890e+01, 7.38102289e+02, 2.51004890e+01, 1.42250000e+00, 2.28368900e+00, 9.69219689e+02, 1.62180890e+01, 4.66748900e+00, 1.30540649e+03, 1.23328900e+00, 7.36382089e+02, 1.01000000e+00, 6.34406489e+02, 3.60100000e+01, 4.64000000e+00, 2.58723560e+01, 1.44488900e+00, 2.69600000e+01, 2.52851560e+01, 9.68022500e+02, 1.60542890e+01, 4.91840000e+02, 3.02500000e-01, 8.15882890e+01, 8.50199089e+02, 1.60900000e+01, 1.30571569e+03, 2.52840890e+01, 5.10250000e+00, 1.09780309e+03, 1.16000000e+00, 1.27335289e+02, 4.04000000e+00, 1.09819909e+03, 5.01384890e+01, 3.26199289e+02, 3.64448890e+01, 8.47250000e+02, 1.44010000e+02, 7.37410000e+02, 3.65625000e+01, 1.03319909e+03, 4.00689000e-01, 7.93922500e+02, 4.91738890e+01, 1.68100000e+01, 9.61308900e+00, 1.66674890e+01, 1.64006890e+01, 1.44966289e+02, 9.12250000e+00, 1.44667489e+02, 7.90165289e+02, 1.65140890e+01, 1.03119849e+03, 6.42025000e+01, 7.93102289e+02, 7.93302500e+02, 2.53806890e+01, 9.90250000e+00, 3.30589489e+02, 4.95625000e+01, 1.23150250e+03, 1.60010890e+01, 3.65000000e+02, 9.10048900e+00, 1.13848900e+00, 6.32935489e+02, 2.84089000e-01, 2.32656289e+02, 
6.43025000e+01, 7.91290000e+02, 1.20250000e+00, 8.27768890e+01, 8.11600000e+01, 3.65625000e+01, 5.17288900e+00, 1.21966289e+02, 1.20250000e+00, 6.84410000e+02, 9.38068900e+00, 1.60176890e+01, 4.13468900e+00, 5.85302500e+02, 1.42250000e+00, 1.52950889e+03, 4.91022500e+02, 4.90225000e+01, 1.61308900e+00, 1.30356250e+03, 6.42332890e+01, 1.30550489e+03, 6.41874890e+01])
And, which samples are closest to the first center?
sqdists_to_center_0 < sqdists_to_center_1
array([False, True, False, True, False, True, False, False, True, False, True, False, False, True, False, True, True, False, True, False, True, True, False, False, False, False, True, False, False, False, False, False, False, False, False, True, True, False, True, False, False, True, False, True, False, False, True, True, False, True, False, False, True, False, True, False, False, True, False, False, True, False, True, False, True, False, False, False, True, False, False, True, False, False, True, False, True, False, False, False, False, False, False, True, False, False, False, False, True, False, True, False, True, False, True, False, False, False, True, False, True, False, True, False, False, True, False, True, False, False, False, True, False, False, True, False, True, False, True, False, True, False, False, True, False, False, True, False, True, False, True, False, True, False, True, False, True, False, True, False, False, True, False, False, False, True, False, True, False, True, False, False, True, False, False, False, False, False, True, False, True, False, True, False, False, False, True, False, True, False, True, True, False, False, False, False, False, True, False, False, True, False, False, False, True, False, False, True, False, True, False, True, False, False, False, False, False, False, True, False, True, False, False, True, False, True, False, False, True, False, False, False, True, False, True, False, True, False, True, False, True, False, True, False, False, False, False, False, False, False, False, True, False, True, False, True, True, False, False, True, False, True, False, True, False, False, True, False, False, False, True, False, False, False, False, False, False, False, True, False, False, False, True, False, True, True, False, False, True, False, True, False])
This approach is easy for $k=2$, but what if $k$ is larger? Can we calculate all of the needed distances in one numpy expression? I bet we can!
centers[:, np.newaxis, :].shape, data.shape
((2, 1, 2), (272, 2))
(centers[:, np.newaxis, :] - data).shape
(2, 272, 2)
np.sum((centers[:, np.newaxis, :] - data)**2, axis=-1).shape
(2, 272)
data.shape
(272, 2)
These are the squared distances between each of our two centers and each of the 272 samples. If we take the `argmin` over the two rows (`axis=0`), we get the index of the closest center for each of the 272 samples.
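Nothing in this expression depends on $k=2$. Here is a sketch with hypothetical toy data and three centers, small enough to check by hand. The names are chosen so they do not overwrite the notebook's `data` and `centers`.

```python
import numpy as np

# Hypothetical toy data: 5 samples in 2-D, and k = 3 centers.
toy_data = np.array([[0.0, 0.0],
                     [0.5, 0.2],
                     [5.0, 5.0],
                     [5.2, 4.8],
                     [9.0, 0.0]])
toy_centers = np.array([[0.0, 0.0],
                        [5.0, 5.0],
                        [9.0, 1.0]])

# (k, 1, p) - (N, p) broadcasts to (k, N, p); summing the last axis gives (k, N).
toy_sqdists = np.sum((toy_centers[:, np.newaxis, :] - toy_data)**2, axis=-1)

# For each sample (column), the row index of its smallest squared distance.
toy_closest = np.argmin(toy_sqdists, axis=0)

print(toy_sqdists.shape)   # (3, 5)
print(toy_closest)         # [0 0 1 1 2]
```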
clusters = np.argmin(np.sum((centers[:, np.newaxis, :] - data)**2, axis=2), axis=0)
clusters
array([1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1])
Now, to calculate the new values of our two centers, we just calculate the mean of the appropriate samples.
data[clusters == 0, :].mean(axis=0)
array([ 2.06631959, 54.39175258])
data[clusters == 1, :].mean(axis=0)
array([ 4.27568 , 80.04571429])
We can update both centers in a `for` loop.
k = 2
for i in range(k):
    centers[i, :] = data[clusters == i, :].mean(axis=0)
centers
array([[ 2.06631959, 54.39175258], [ 4.27568 , 80.04571429]])
Now we can wrap these steps in our first version of a `kmeans` function.
def kmeans(data, k=2, n_iterations=5):
    # Initial centers: k distinct samples chosen at random
    centers = data[np.random.choice(range(data.shape[0]), k, replace=False), :]
    # Repeat n_iterations times
    for iteration in range(n_iterations):
        # Which center is each sample closest to?
        closest = np.argmin(np.sum((centers[:, np.newaxis, :] - data)**2, axis=2), axis=0)
        # Update cluster centers
        for i in range(k):
            centers[i, :] = data[closest == i, :].mean(axis=0)
    return centers
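One hazard this function does not handle: if no sample is closest to some center, `data[closest == i, :]` is empty. Its mean is then `nan`, and NumPy emits a `RuntimeWarning`. This becomes more likely as $k$ grows. Below is a minimal sketch of one common remedy, which keeps the previous center when its cluster is empty. The helper name `update_centers` is our own. Another option, not shown, is to re-seed the empty center with a random sample.

```python
import numpy as np

def update_centers(data, centers, closest):
    """Return new centers as cluster means; an empty cluster keeps its old center."""
    new_centers = centers.copy()
    for i in range(centers.shape[0]):
        members = data[closest == i, :]
        if members.shape[0] > 0:          # only average non-empty clusters
            new_centers[i, :] = members.mean(axis=0)
    return new_centers

toy_data = np.array([[0.0, 0.0], [1.0, 1.0], [10.0, 10.0]])
toy_centers = np.array([[0.0, 0.0], [10.0, 10.0], [100.0, 100.0]])  # third center attracts no samples
toy_closest = np.argmin(np.sum((toy_centers[:, np.newaxis, :] - toy_data)**2, axis=2), axis=0)
print(toy_closest)                                   # [0 0 1]
print(update_centers(toy_data, toy_centers, toy_closest))
```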
kmeans(data, 2, 5)
array([[ 4.29793023, 80.28488372], [ 2.09433 , 54.75 ]])
kmeans(data, 2)
array([[ 2.09433 , 54.75 ], [ 4.29793023, 80.28488372]])
We need a measure of the quality of our clustering. For this, we define $J$, the performance measure that k-means minimizes: $$ J = \sum_{n=1}^N \sum_{k=1}^K r_{nk} ||\mathbf{x}_n - \boldsymbol{\mu}_k||^2 $$ where $N$ is the number of samples, $K$ is the number of cluster centers, $\mathbf{x}_n$ is the $n^{th}$ sample, and $\boldsymbol{\mu}_k$ is the $k^{th}$ center, each an element of $\mathbb{R}^p$ where $p$ is the dimensionality of the data. $r_{nk}$ is 1 if $\mathbf{x}_n$ is closest to center $\boldsymbol{\mu}_k$, and 0 otherwise.
The sums could be computed with Python `for` loops, but, as you know, loops in Python are much slower than vectorized numpy operations, so let's do the matrix magic. We already know how to calculate the squared distances between all samples and all centers.
sqdists = np.sum((centers[:,np.newaxis,:] - data)**2, axis=2)
sqdists.shape
(2, 272)
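The calculation of $J$ requires us to multiply the squared distance between sample $n$ and center $k$ by $r_{nk}$. Since $r_{nk}$ is 1 only for the closest center, the inner sum over $k$ is just the squared distance from sample $n$ to its closest center. We already have all of the squared distances, so we just sum the minimum distance for each sample.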
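To confirm that the shortcut matches the literal double sum, we can build the indicator values $r_{nk}$ explicitly and compare the two results. This sketch uses random hypothetical data. For convenience, the indicator array is stored as `r[k, n]`, transposed relative to $r_{nk}$, so that it lines up with the `(K, N)` layout of the squared distances.

```python
import numpy as np

rng = np.random.default_rng(0)
toy_data = rng.normal(size=(6, 2))          # hypothetical data: N = 6 samples, p = 2
toy_centers = toy_data[:2, :].copy()        # hypothetical centers: K = 2

toy_sqdists = np.sum((toy_centers[:, np.newaxis, :] - toy_data)**2, axis=2)  # (K, N)

# Indicator matrix stored as r[k, n]: 1 where center k is closest to sample n.
toy_closest = np.argmin(toy_sqdists, axis=0)
r = np.zeros_like(toy_sqdists)
r[toy_closest, np.arange(toy_data.shape[0])] = 1

J_explicit = np.sum(r * toy_sqdists)             # literal sum of r_nk * ||x_n - mu_k||^2
J_min = np.sum(np.min(toy_sqdists, axis=0))      # shortcut: sum of per-sample minima

print(np.isclose(J_explicit, J_min))             # True
```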
np.min(sqdists, axis=0)
array([1.55006183e+00, 2.24396205e-01, 3.74393068e+01, 5.79323792e+01, 2.46111605e+01, 1.03693182e+00, 6.34507087e+01, 2.50014904e+01, 1.15175158e+01, 2.45504704e+01, 2.07908112e-01, 1.57650269e+01, 4.19067440e+00, 5.47380642e+01, 8.90785154e+00, 5.73061694e+00, 5.79854869e+01, 1.59112870e+01, 5.93793435e+00, 1.09417783e+00, 1.15749117e+01, 5.47380642e+01, 4.86669440e+00, 1.23468711e+02, 3.66168748e+01, 9.18434754e+00, 3.79829308e-01, 1.64049297e+01, 4.36615040e+00, 1.11826795e+00, 4.96426813e+01, 9.31297885e+00, 1.36443178e+02, 6.09833783e-02, 3.67466268e+01, 5.72291281e+00, 4.08942293e+01, 3.12695378e-01, 2.12903823e+01, 9.93451777e+01, 7.61325832e-03, 1.30530555e+01, 1.57212429e+01, 1.31195075e+01, 4.97083034e+01, 9.64687142e+00, 9.54395782e+01, 1.93810961e+00, 3.94691024e+00, 2.12403426e+01, 2.57341441e+01, 9.92816858e+01, 2.07908112e-01, 3.12695378e-01, 2.64572029e-01, 9.09664166e+00, 8.21370703e+01, 9.24778747e+01, 9.36124285e+00, 9.12368567e-01, 2.12637267e+01, 1.56866950e+01, 4.09545591e+01, 4.09414412e+00, 3.15145994e+01, 1.42920402e+02, 4.19675828e+00, 4.36499440e+00, 1.12534914e+02, 4.98221373e+01, 3.87812624e+00, 2.59632415e+00, 1.14383783e+00, 8.19009464e+01, 5.78923710e+01, 1.69939914e+01, 3.14548716e+01, 4.26981428e+00, 1.65220017e+01, 9.18434754e+00, 2.54795902e+01, 3.82251824e+00, 1.00947239e+02, 1.12856040e+02, 4.96856371e+01, 6.37027308e+01, 1.64738715e+01, 6.03251383e-02, 4.08646376e+01, 3.55295178e+01, 3.14703096e+01, 9.90910897e+01, 1.93272190e+01, 4.47797428e+00, 7.41563617e+01, 6.47341098e+01, 1.57895069e+01, 2.57355721e+01, 1.15437138e+01, 4.20900812e+00, 5.80590514e+01, 6.32790006e+01, 2.90721302e+01, 8.77812354e+00, 9.61592687e-01, 5.46777345e+01, 1.58164230e+01, 5.80075038e+00, 3.57833618e+01, 1.26193081e+00, 2.56683742e+01, 2.12905508e+01, 8.05690081e+01, 1.11348971e+00, 2.13701343e+01, 1.03833881e+00, 1.93503314e+01, 2.46501304e+01, 2.12981046e+01, 4.83820611e+01, 2.24022415e+00, 1.22051351e+02, 9.27703497e+00, 2.59632415e+00, 
6.33758447e+01, 1.16941657e+00, 8.82273128e+01, 3.86955212e+00, 4.10237555e-01, 9.92279195e+01, 8.82447448e+01, 8.73961542e+00, 3.12474672e+00, 8.01825182e+01, 7.04759493e+01, 3.83075024e+00, 1.15375916e+01, 3.58855879e+01, 1.93808543e+00, 1.38801995e+00, 9.12482807e-01, 3.14802215e+01, 3.88544624e+00, 9.56940285e+00, 1.63710897e+01, 2.12428865e+01, 1.29767378e-01, 2.90734283e+01, 2.55218736e+02, 2.00790136e+00, 9.84990909e+00, 9.35237497e+00, 1.12646256e+02, 1.01584469e+00, 8.23271743e+01, 1.00992375e+02, 9.60980687e-01, 1.67850644e+02, 2.00790136e+00, 8.02745160e+01, 8.82228869e+01, 3.54693138e+01, 1.30238478e+01, 4.38091252e+00, 1.36806848e+02, 1.64622497e+01, 7.41923324e+01, 6.37953007e+01, 5.73825450e+00, 1.67930018e+02, 2.90932922e+01, 6.80323285e+00, 9.37082109e+00, 1.45987878e+02, 9.22472567e-01, 9.13946807e-01, 4.96924093e+01, 1.94104675e+01, 2.46209464e+01, 3.65624726e+01, 4.03570998e-01, 9.37082109e+00, 8.72846354e+00, 8.98655942e+00, 1.15050957e+01, 4.20969652e+00, 1.56735011e+01, 7.04759493e+01, 8.74777542e+00, 3.83579246e-01, 1.18557269e+00, 6.85739265e+00, 1.66427155e+01, 1.56672390e+01, 9.37227721e+00, 9.12482807e-01, 4.89637693e+01, 9.28453321e+00, 1.15377240e+01, 4.33807828e+00, 3.14535735e+01, 3.82475612e+00, 1.20016733e+02, 1.97670353e+00, 4.29013040e+00, 7.05017813e+01, 9.28471485e+00, 1.58175790e+01, 2.90887700e+01, 8.77812354e+00, 8.54071845e+01, 1.82137258e-01, 2.91107242e+01, 2.56551982e+01, 9.41427561e+01, 1.63696257e+01, 2.04831785e+00, 1.94997001e+02, 3.74363215e-01, 1.63835995e+01, 1.93272190e+01, 3.81930800e+00, 2.53528163e-01, 2.55022142e+01, 4.26094640e+00, 1.11869771e+00, 4.22207252e+00, 4.18502228e+00, 1.01045027e+02, 1.16876983e+00, 1.00953501e+02, 2.76446833e-01, 3.54621079e+01, 1.93101953e+01, 9.91181915e+01, 1.87076153e-01, 2.00264246e-01, 9.27642909e+00, 1.19958583e+00, 9.23895370e+01, 2.54750281e+01, 5.47184807e+01, 3.58855879e+01, 7.47969467e+01, 2.46393925e+01, 4.01519824e+00, 6.80323285e+00, 3.82757200e+00, 1.58972349e+02, 
3.65561847e+01, 1.71340534e-01, 8.75819154e+00, 5.01443171e+01, 4.96924093e+01, 6.32864567e+01, 2.12477138e-01, 8.19535983e+01, 8.75819154e+00, 2.59085806e+00, 1.09357195e+00, 4.42634228e+00, 1.57025891e+01, 1.30662436e+01, 8.72846354e+00, 1.29778969e+02, 3.14861776e+01, 2.56842121e+01, 9.35840567e-01, 7.04285137e+01, 9.91077754e+01, 7.04836716e+01, 3.65872646e+01])
np.sum(np.min(sqdists, axis=0))
8924.605200606595
Let's define a function named calcJ to do this calculation.
def calcJ(data, centers):
    sqdists = np.sum((centers[:, np.newaxis, :] - data)**2, axis=2)
    return np.sum(np.min(sqdists, axis=0))
calcJ(data, centers)
8924.605200606595
Now we can add this calculation to track the value of $J$ at each iteration, as a kind of learning curve. $J$ measures the total squared "spread" of the samples around their cluster centers, so the smaller it is, the better.
def kmeans(data, k, n_iterations):
    # Initialize centers and list J to track performance metric
    centers = data[np.random.choice(range(data.shape[0]), k, replace=False), :]
    J = []
    for iteration in range(n_iterations):
        # Which center is each sample closest to?
        sqdistances = np.sum((centers[:, np.newaxis, :] - data)**2, axis=2)
        closest = np.argmin(sqdistances, axis=0)
        # Calculate J and append to list J
        J.append(calcJ(data, centers))
        # Update cluster centers
        for i in range(k):
            centers[i, :] = data[closest == i, :].mean(axis=0)
    # Calculate J one final time and return results.
    # Note: closest holds the assignments made before the final center update.
    J.append(calcJ(data, centers))
    return centers, J, closest
centers, J, closest = kmeans(data, 2, 10)
J
[76592.212941, 31417.137381206157, 17979.045594406154, 10358.428460937701, 8925.715281184834, 8901.76872094721, 8901.76872094721, 8901.76872094721, 8901.76872094721, 8901.76872094721, 8901.76872094721]
plt.plot(J);
centers, J, closest = kmeans(data, 2, 10)
plt.plot(J);
centers
array([[ 4.29793023, 80.28488372], [ 2.09433 , 54.75 ]])
closest
array([0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0])
centers, J, closest = kmeans(data, 2, 2)
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.scatter(data[:, 0], data[:, 1], s=80, c=closest, alpha=0.5)
plt.scatter(centers[:, 0], centers[:, 1], s=80, c="green", marker='D')
plt.subplot(1, 2, 2)
plt.plot(J)
centers
array([[ 2.02314444, 53.61111111], [ 4.21205495, 79.44505495]])
Let's try more iterations.
centers, J, closest = kmeans(data, 2, 10)
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.scatter(data[:, 0], data[:, 1], s=80, c=closest, alpha=0.5)
plt.scatter(centers[:, 0], centers[:, 1], s=80, c="green", marker='D')
plt.subplot(1, 2, 2)
plt.plot(J)
centers
array([[ 4.29793023, 80.28488372], [ 2.09433 , 54.75 ]])
Now, how about three centers, so $k=3$?
centers, J, closest = kmeans(data, 3, 10)
plt.figure(figsize=(15, 8))
plt.subplot(1, 2, 1)
plt.scatter(data[:, 0], data[:, 1], s=80, c=closest, alpha=0.5)
plt.scatter(centers[:, 0], centers[:, 1], s=80, c="green", marker='D')
plt.subplot(1, 2, 2)
plt.plot(J)
centers
array([[ 3.66606667, 70.65 ], [ 2.00583133, 52.86746988], [ 4.35836434, 82.6124031 ]])
centers, J, closest = kmeans(data, 4, 10)
plt.figure(figsize=(15, 8))
plt.subplot(1, 2, 1)
plt.scatter(data[:, 0], data[:, 1], s=80, c=closest, alpha=0.5)
plt.scatter(centers[:, 0], centers[:, 1], s=80, c="green", marker='D')
plt.subplot(1, 2, 2)
plt.plot(J)
centers
array([[ 1.99635593, 50.6440678 ], [ 2.26145238, 60.83333333], [ 4.2403908 , 75.95402299], [ 4.3690119 , 84.91666667]])
Or six centers?
centers, J, closest = kmeans(data, 6, 20)
plt.figure(figsize=(15, 8))
plt.subplot(1, 2, 1)
plt.scatter(data[:, 0], data[:, 1], s=80, c=closest, alpha=0.5)
plt.scatter(centers[:, 0], centers[:, 1], s=80, c="green", marker='D')
plt.subplot(1, 2, 2)
plt.plot(J)
centers
array([[ 4.29412281, 81.21052632], [ 4.347875 , 85. ], [ 2.26965789, 61.34210526], [ 4.54438095, 90.19047619], [ 4.22268116, 75.04347826], [ 2.0082381 , 50.98412698]])
So, clustering two-dimensional data is not all that exciting. How about 784-dimensional data, such as our good buddy the MNIST data set?
import gzip
import pickle
with gzip.open('mnist.pkl.gz', 'rb') as f:
train_set, valid_set, test_set = pickle.load(f, encoding='latin1')
Xtrain = train_set[0]
Ttrain = train_set[1].reshape((-1,1))
Xtest = test_set[0]
Ttest = test_set[1].reshape((-1,1))
Xtrain.shape, Ttrain.shape, Xtest.shape, Ttest.shape
((50000, 784), (50000, 1), (10000, 784), (10000, 1))
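Before running `kmeans` on MNIST, note its memory cost. `centers[:, np.newaxis, :] - data` materializes a $k \times N \times p$ array. With $k=40$, $N=50000$, $p=784$, and 4-byte `float32` values, that is about 6.3 GB. A lower-memory alternative uses the expansion $||\mathbf{x} - \boldsymbol{\mu}||^2 = ||\mathbf{x}||^2 - 2\,\mathbf{x}^\top\boldsymbol{\mu} + ||\boldsymbol{\mu}||^2$, which only needs $k \times N$ intermediate arrays. The sketch below implements it and checks it against the broadcast version on random stand-in data. The function name `sqdists_expanded` is our own. Small negative values caused by floating-point round-off are clipped to zero.

```python
import numpy as np

def sqdists_expanded(centers, data):
    """Squared Euclidean distances, shape (k, N), without a (k, N, p) intermediate."""
    centers_sq = np.sum(centers**2, axis=1)[:, np.newaxis]    # (k, 1)
    data_sq = np.sum(data**2, axis=1)[np.newaxis, :]          # (1, N)
    cross = centers @ data.T                                  # (k, N)
    return np.maximum(centers_sq - 2 * cross + data_sq, 0.0)  # clip round-off below zero

rng = np.random.default_rng(1)
fake_data = rng.random((100, 784))          # hypothetical stand-in for MNIST rows
fake_centers = fake_data[:10, :].copy()

direct = np.sum((fake_centers[:, np.newaxis, :] - fake_data)**2, axis=2)
fast = sqdists_expanded(fake_centers, fake_data)
print(fast.shape)                  # (10, 100)
print(np.allclose(direct, fast))   # True
```

Swapping this into `kmeans` and `calcJ` would change only how the distances are computed, not the algorithm.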
How many clusters shall we use?
centers, J, closest = kmeans(Xtrain, k=10, n_iterations=10)
plt.plot(J);
centers.shape
(10, 784)
for i in range(10):
    plt.subplot(2, 5, i + 1)
    plt.imshow(-centers[i, :].reshape((28, 28)), cmap='gray')
    plt.axis('off')
Try more iterations.
centers, J, closest = kmeans(Xtrain, k=10, n_iterations=20)
plt.plot(J)
plt.figure()
for i in range(10):
    plt.subplot(2, 5, i + 1)
    plt.imshow(-centers[i, :].reshape((28, 28)), cmap='gray')
    plt.axis('off')
And more centers.
centers, J, closest = kmeans(Xtrain, k=20, n_iterations=20)
plt.plot(J)
plt.figure()
for i in range(20):
    plt.subplot(4, 5, i + 1)
    plt.imshow(-centers[i, :].reshape((28, 28)), interpolation='nearest', cmap='gray')
    plt.axis('off')
Try that again. Do the cluster centers differ?
centers, J, closest = kmeans(Xtrain, k=20, n_iterations=20)
plt.plot(J)
plt.figure()
for i in range(20):
    plt.subplot(4, 5, i + 1)
    plt.imshow(-centers[i, :].reshape((28, 28)), interpolation='nearest', cmap='gray')
    plt.axis('off')
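Because the initial centers are chosen at random, repeated runs can converge to different local minima of $J$. A common remedy is to run k-means several times and keep the run with the smallest final $J$. Below is a self-contained sketch under our own assumptions. The names `kmeans_once` and `kmeans_best_of` are hypothetical, the random generator is seeded for reproducibility, empty clusters keep their previous center, and the three-blob data is made up.

```python
import numpy as np

def kmeans_once(data, k, n_iterations, rng):
    """One k-means run from random initial centers; returns (centers, final J)."""
    centers = data[rng.choice(data.shape[0], k, replace=False), :].astype(float)
    for _ in range(n_iterations):
        sqdists = np.sum((centers[:, np.newaxis, :] - data)**2, axis=2)
        closest = np.argmin(sqdists, axis=0)
        for i in range(k):
            members = data[closest == i, :]
            if members.shape[0] > 0:            # empty cluster: keep its old center
                centers[i, :] = members.mean(axis=0)
    sqdists = np.sum((centers[:, np.newaxis, :] - data)**2, axis=2)
    return centers, np.sum(np.min(sqdists, axis=0))

def kmeans_best_of(data, k, n_iterations, n_restarts, seed=0):
    """Run k-means n_restarts times; keep the centers with the smallest final J."""
    rng = np.random.default_rng(seed)
    best_centers, best_J = None, np.inf
    for _ in range(n_restarts):
        centers, J = kmeans_once(data, k, n_iterations, rng)
        if J < best_J:
            best_centers, best_J = centers, J
    return best_centers, best_J

# Hypothetical toy data: three 2-D blobs of 50 samples each.
toy_rng = np.random.default_rng(42)
toy = np.vstack([toy_rng.normal((0, 0), 0.5, (50, 2)),
                 toy_rng.normal((5, 5), 0.5, (50, 2)),
                 toy_rng.normal((0, 5), 0.5, (50, 2))])
best_centers, best_J = kmeans_best_of(toy, k=3, n_iterations=20, n_restarts=10)
print(best_centers.round(2))
print(best_J)
```

Each restart costs a full k-means run, which adds up on MNIST. scikit-learn's `KMeans` class applies the same idea through its `n_init` parameter.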
Maybe more clusters will help. Do we see expected variations for each digit?
centers, J, closest = kmeans(Xtrain, k=40, n_iterations=20)
plt.plot(J)
plt.figure()
for i in range(40):
    plt.subplot(4, 10, i + 1)
    plt.imshow(-centers[i, :].reshape((28, 28)), interpolation='nearest', cmap='gray')
    plt.axis('off')
centers, J, closest = kmeans(Xtrain, k=40, n_iterations=20)
plt.plot(J)
plt.figure()
for i in range(40):
    plt.subplot(4, 10, i + 1)
    plt.imshow(-centers[i, :].reshape((28, 28)), interpolation='nearest', cmap='gray')
    plt.axis('off')
Hmm. Some of these look pretty fuzzy. Let's see how many samples are in each cluster, and show the counts in the titles.
(closest == 0).sum()
1239
for i in range(40):
    plt.subplot(4, 10, i + 1)
    plt.imshow(-centers[i, :].reshape((28, 28)), interpolation='nearest', cmap='gray')
    plt.title(str((closest == i).sum()))
    plt.axis('off')
How could you use the results of the `kmeans` clustering algorithm as the first step in a classification algorithm?
Now that we have some experience in calculating distances between samples, we are a short step away from an implementation of a common classification algorithm called k-nearest-neighbor. This is a non-parametric algorithm, meaning that it does not learn parameters, like weights, to make its decisions. Instead, we could call it a memory-based method. It classifies a sample by finding the $k$ closest samples in the training set and returning the most common class label among those $k$ nearest samples. (This $k$ is unrelated to the $k$ in k-means.)
Training is terribly simple. We just have to store the training samples. Classification is also trivial to code. We just calculate squared distances between training samples and the samples being classified and return the most common class label among the $k$ closest training samples.
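Training and classification as described here fit in a few lines. The following is a minimal functional sketch, separate from the class developed below. It assumes 1-D, non-negative integer labels (the notebook's `Ttrain` is a column vector, so it would need `.ravel()` first). The function `knn_classify` and the toy data are hypothetical, and the majority vote uses `np.bincount` instead of `scipy.stats.mode`.

```python
import numpy as np

def knn_classify(X, T, Xnew, k):
    """Label each row of Xnew by majority vote among its k nearest rows of X."""
    # (n_new, 1, p) - (n_train, p) -> (n_new, n_train, p) -> sum over p -> (n_new, n_train)
    sqdists = np.sum((Xnew[:, np.newaxis, :] - X)**2, axis=2)
    nearest = np.argsort(sqdists, axis=1)[:, :k]     # indices of the k closest rows of X
    labels = T[nearest]                              # (n_new, k) class labels
    # Most common label per row; np.bincount needs non-negative integer labels.
    return np.array([np.bincount(row).argmax() for row in labels])

X_toy = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0],
                  [5.0, 5.0], [5.0, 6.0], [6.0, 5.0]])
T_toy = np.array([0, 0, 0, 1, 1, 1])                 # hypothetical 1-D labels
X_new_toy = np.array([[0.5, 0.5], [5.5, 5.5], [4.0, 4.0]])
print(knn_classify(X_toy, T_toy, X_new_toy, k=3))    # [0 1 1]
```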
Let's create a class named `KNN` to implement this algorithm.
First, let's practice our numpy-foo to see how to pick the most common class with a minimum amount of code.
Remember that `sqdists` from above is `n_centers` x `n_samples`. This time, the rows will be the samples we want to classify instead of cluster centers.
Let's try to classify the first three MNIST test samples.
sqdists = np.sum((Xtest[:3, np.newaxis, :] - Xtrain)**2, axis=2)
sqdists.shape
(3, 50000)
Okay. Now all we have to do is find the $k$ smallest distances in each row. Let's use $k=5$.
k = 5
np.sort(sqdists[0, :])[:k]
array([ 9.6193695, 11.355759 , 11.403915 , 12.214478 , 12.627594 ], dtype=float32)
But we need the indices of these values so we can look up their class labels in `Ttrain`.
k = 5
np.argsort(sqdists[0, :])[:k]
array([38620, 16186, 27059, 47003, 14563])
Now we have to do this for each row in `sqdists`. Or do we? Wouldn't it be nice if `np.argsort` could sort each row independently, so we can do this in one function call?
np.sort(sqdists, axis=1)
array([[ 9.6193695, 11.355759 , 11.403915 , ..., 211.29517 , 211.7034 , 235.06609 ], [ 20.636139 , 22.408554 , 25.232117 , ..., 215.19962 , 216.1485 , 221.52444 ], [ 1.6865845, 1.7748108, 2.063202 , ..., 227.82173 , 245.48021 , 251.6908 ]], dtype=float32)
Yippee!
np.argsort(sqdists, axis=1)
array([[38620, 16186, 27059, ..., 10259, 25321, 41358], [28882, 49160, 24612, ..., 43452, 10237, 13650], [46512, 15224, 47333, ..., 25321, 25285, 41358]])
indices = np.argsort(sqdists, axis=1)
indices
array([[38620, 16186, 27059, ..., 10259, 25321, 41358], [28882, 49160, 24612, ..., 43452, 10237, 13650], [46512, 15224, 47333, ..., 25321, 25285, 41358]])
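Sorting all 50,000 distances in each row is more work than needed when only the $k$ smallest matter. `np.argpartition` finds them without a full sort. The sketch below runs on random stand-in distances (not the MNIST values) and checks that the result agrees with `np.argsort`. The variable names are illustrative and chosen so they do not overwrite `sqdists`, `k`, or `indices`.

```python
import numpy as np

# Hypothetical stand-in for the (3, 50000) matrix of squared distances.
fake_rng = np.random.default_rng(0)
fake_sqdists = fake_rng.random((3, 50000))
k_nearest = 5

# argpartition places the k_nearest smallest entries of each row in the first
# k_nearest positions (in arbitrary order) without fully sorting the row.
part = np.argpartition(fake_sqdists, k_nearest, axis=1)[:, :k_nearest]

# Sort just those candidates so the nearest neighbor comes first.
order = np.argsort(np.take_along_axis(fake_sqdists, part, axis=1), axis=1)
nearest = np.take_along_axis(part, order, axis=1)

full = np.argsort(fake_sqdists, axis=1)[:, :k_nearest]
print(np.array_equal(nearest, full))   # True
```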
np.squeeze(Ttrain[indices, :]).shape
(3, 50000)
plt.imshow(-Xtest[2, :].reshape(28,28), cmap='gray')
plt.axis('off')
(-0.5, 27.5, 27.5, -0.5)
Ttrain[indices, :][:, :, 0]
array([[7, 7, 7, ..., 0, 8, 0], [2, 2, 2, ..., 4, 0, 4], [1, 1, 1, ..., 8, 0, 0]])
np.unique(Ttrain[indices, :][:, :, 0][:, :40], axis=1, return_counts=True)
(array([[7], [2], [1]]), array([40]))
Cool! Now we just have to take the first $k$ columns of these and determine the most common label across the columns, for each row. We can use scipy.stats.mode for this!
import scipy.stats as ss
ss.mode([1, 2, 3, 4, 2, 2, 2])  # most common value and how many times it occurs
ModeResult(mode=2, count=4)
ss.mode(Ttrain[indices, :][:, :, 0][:, :10], axis=1)  # most common label in each row
ModeResult(mode=array([7, 2, 1]), count=array([10, 10, 10]))
Ttrain[indices, :][:, :, 0][:, :10]
array([[7, 7, 7, 7, 7, 7, 7, 7, 7, 7], [2, 2, 2, 2, 2, 2, 2, 2, 2, 2], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
Ttest[:3]
array([[7], [2], [1]])
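For non-negative integer class labels like these digits, the same per-row vote can also be sketched with plain NumPy. The neighbor labels below are made up. Like scipy.stats.mode, which returns the smallest of several equally common values, np.argmax on the counts returns the smallest tied label.

```python
import numpy as np

# Made-up labels of the 5 nearest neighbors for three samples
neighbor_labels = np.array([[7, 7, 1, 7, 2],
                            [2, 3, 3, 2, 3],
                            [1, 4, 4, 1, 9]])

def row_mode(labels):
    counts = np.bincount(labels)   # counts[c] = number of neighbors with label c
    return np.argmax(counts)       # ties go to the smallest label

modes = np.array([row_mode(row) for row in neighbor_labels])
print(modes)   # third row is a 1-versus-4 tie, resolved to 1
```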
All three are correct here, but other test samples may be classified differently with different values of $k$.
Finally, we can now define our KNN class.
import numpy as np
import scipy.stats as ss  # for ss.mode
class KNN():
    def __init__(self):
        self.X = None  # data will be stored here
        self.T = None  # class labels will be stored here
        self.Xmeans = None
        self.Xstds = None
    def train(self, X, T):
        if self.Xmeans is None:
            self.Xmeans = X.mean(axis=0)
            self.Xstds = X.std(axis=0)
            self.Xstds[self.Xstds == 0] = 1  # avoid dividing by zero for constant features
        self.X = self._standardizeX(X)
        self.T = T
    def _standardizeX(self, X):
        return (X - self.Xmeans) / self.Xstds
    def use(self, Xnew, k=1):
        self.k = k
        # Calc squared distance from all samples in Xnew with all stored in self.X
        sqdists = np.sum((self._standardizeX(Xnew)[:, np.newaxis, :] - self.X)**2, axis=-1)
        # sqdists is now n_new_samples x n_train_samples
        # Sort each row of squared distances from smallest to largest and select the first k.
        indices = np.argsort(sqdists, axis=1)[:, :k]
        # Determine most common class label in each row.
        classes = ss.mode(self.T[indices, :][:, :, 0], axis=1)[0]
        return classes
knn = KNN()
knn
<__main__.KNN at 0x7f9bb2fa29d0>
Oh, we can't have that! Let's add a __repr__ method so a KNN object describes itself, and have train return self.
import numpy as np
import scipy.stats as ss  # for ss.mode
class KNN():
    def __init__(self):
        self.X = None  # data will be stored here
        self.T = None  # class labels will be stored here
        self.Xmeans = None
        self.Xstds = None
    def __repr__(self):
        if self.X is None:
            return 'KNN() has not been trained.'
        else:
            return f'KNN(), trained with {self.X.shape[0]} samples having class labels {np.unique(self.T)}.'
    def train(self, X, T):
        if self.Xmeans is None:
            self.Xmeans = X.mean(axis=0)
            self.Xstds = X.std(axis=0)
            self.Xstds[self.Xstds == 0] = 1  # avoid dividing by zero for constant features
        self.X = self._standardizeX(X)
        self.T = T
        return self
    def _standardizeX(self, X):
        return (X - self.Xmeans) / self.Xstds
    def use(self, Xnew, k=1):
        if self.X is None:
            raise Exception('KNN object has not been trained yet.')
        self.k = k
        # Calc squared distance from all samples in Xnew with all stored in self.X
        sqdists = np.sum((self._standardizeX(Xnew)[:, np.newaxis, :] - self.X)**2, axis=-1)
        # sqdists is now n_new_samples x n_train_samples
        # Sort each row of squared distances from smallest to largest and select the first k.
        indices = np.argsort(sqdists, axis=1)[:, :k]
        # Determine most common class label in each row.
        classes = ss.mode(self.T[indices, :][:, :, 0], axis=1)[0]
        return classes
knn = KNN()
knn
KNN() has not been trained.
knn.train(Xtrain, Ttrain)
KNN(), trained with 50000 samples having class labels [0 1 2 3 4 5 6 7 8 9].
Boy, that took a long time to train! :) 200 ms.
Let's test it. First, use the default value of $k$, which is 1.
knn.use(Xtest[:3, :])
array([7, 2, 1])
Ttest[:3]
array([[7], [2], [1]])
Well, that worked perfectly. Let's try more test samples.
knn.use(Xtest[:10, :])
array([7, 2, 1, 0, 4, 1, 4, 4, 4, 9])
Ttest[:10]
array([[7], [2], [1], [0], [4], [1], [4], [9], [5], [9]])
There are some mistakes. How about using more neighbors?
plt.imshow(-Xtest[8, :].reshape(28, 28), cmap='gray')
plt.axis('off')
(-0.5, 27.5, 27.5, -0.5)
knn.use(Xtest[:10, :], k=7)
array([7, 2, 1, 0, 4, 1, 4, 9, 4, 9])
Ttest[:10]
array([[7], [2], [1], [0], [4], [1], [4], [9], [5], [9]])
def percent_correct(Predicted, T):
return 100 * np.mean(Predicted == T)
percent_correct(knn.use(Xtest[:10, :], k=5), Ttest[:10])
17.0
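That 17.0 is far lower than the comparisons with Ttest[:10] above suggest, and the cause is broadcasting: knn.use returns an array of shape (10,), while Ttest[:10] has shape (10, 1), so Predicted == T compares every prediction with every target and produces a 10 x 10 array. Flattening both arrays avoids this. The small arrays below are made up to show the difference.

```python
import numpy as np

def percent_correct(Predicted, T):
    # Flatten both so (n,) and (n, 1) arrays are compared element by element
    return 100 * np.mean(Predicted.reshape(-1) == T.reshape(-1))

Predicted = np.array([1, 2, 3])   # shape (3,), like the output of knn.use
T = np.array([[1], [2], [4]])     # shape (3, 1), like Ttest

print((Predicted == T).shape)          # broadcasting gives a 3 x 3 array
print(100 * np.mean(Predicted == T))   # original version: 2 matches out of 9 comparisons
print(percent_correct(Predicted, T))   # fixed version: 2 matches out of 3 samples
```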
Now we can try multiple values of $k$ with a for loop, and test all test samples.
# results = []
# for k in range(1, 5):
#     print(k, end=' ')
#     pc = percent_correct(knn.use(Xtest, k=k), Ttest)
#     results.append([k, pc])
Python kernel died.
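Well, here is what we often face when dealing with big data sets. k-nearest-neighbor calculates the squared distance between every training sample and every test sample, and use builds an intermediate array of differences with shape n_test_samples x n_train_samples x n_features. For all of Xtest against 50,000 training images of 784 pixels each, that array is far too large to fit in memory.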
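We can deal with this in the typical way: by processing the test samples in batches. To keep the computation time down, we will also train on only the first 5,000 training samples.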
n_train = 5000  # To reduce computation time
knn = KNN()
knn.train(Xtrain[:n_train, :], Ttrain[:n_train, :])
batch_size = 500
n_samples = Xtest.shape[0]
results = []
for k in [1, 2, 5, 10, 20]:
    n_correct = 0
    for first in range(0, n_samples, batch_size):
        X = Xtest[first:first + batch_size, :]
        T = Ttest[first:first + batch_size, :]
        n_correct += np.sum(knn.use(X, k=k) == T.reshape(-1))
    pc = n_correct / n_samples * 100
    results.append([k, pc])
    print(results[-1])
[1, 89.18]
[2, 87.77000000000001]
[5, 89.17]
[10, 88.74]
[20, 87.39]
results
[[1, 89.18], [2, 87.77000000000001], [5, 89.17], [10, 88.74], [20, 87.39]]
results = np.array(results)
plt.plot(results[:, 0], results[:, 1])
plt.xlabel('$k$')
plt.ylabel('Percent Correct Test Data');
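Batching bounds how many test samples are handled at once, but each call to use still builds a batch_size x n_train x n_features array of differences. One common alternative, not used in the KNN class here, expands the squared distance as $\|a - b\|^2 = \|a\|^2 - 2\,a \cdot b + \|b\|^2$, so the largest intermediate array is only n_new x n_train. The function name sqdists_expanded is hypothetical; this sketch checks it against the broadcasting version on random data.

```python
import numpy as np

def sqdists_expanded(A, B):
    """Squared Euclidean distances between rows of A (n_a x d) and rows of B (n_b x d)."""
    A_sq = np.sum(A ** 2, axis=1)[:, np.newaxis]   # (n_a, 1)
    B_sq = np.sum(B ** 2, axis=1)[np.newaxis, :]   # (1, n_b)
    d = A_sq - 2 * A @ B.T + B_sq                  # (n_a, n_b), no n_a x n_b x d array
    return np.maximum(d, 0)   # clip tiny negative values caused by floating-point rounding

rng = np.random.default_rng(2)
A = rng.normal(size=(4, 10))
B = rng.normal(size=(7, 10))
broadcast = np.sum((A[:, np.newaxis, :] - B) ** 2, axis=2)
print(np.allclose(sqdists_expanded(A, B), broadcast))
```

The trade-off is some loss of floating-point precision when two samples are very close together, which is why the result is clipped at zero.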
How might you change the implementation of KNN to speed up this calculation using multiple $k$ values?
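One possible approach, sketched here under the assumption that the squared distances are already computed, is to sort each row once and then vote over the first $k$ columns for every $k$. The function classify_for_several_k is hypothetical, and the vote uses np.bincount, so labels must be non-negative integers.

```python
import numpy as np

def classify_for_several_k(sqdists, T, ks):
    """Return {k: predicted labels}, sorting each row of sqdists only once.

    sqdists: (n_new, n_train) squared distances; T: (n_train, 1) integer labels.
    """
    nearest = np.argsort(sqdists, axis=1)[:, :max(ks)]   # one sort for all k values
    labels = T[nearest, 0]                                 # (n_new, max(ks)) neighbor labels
    return {k: np.array([np.argmax(np.bincount(row[:k])) for row in labels])
            for k in ks}

# Made-up distances from 2 new samples to 5 training samples
sqdists = np.array([[0.1, 0.5, 0.2, 0.9, 0.3],
                    [0.8, 0.1, 0.7, 0.2, 0.6]])
T = np.array([[1], [2], [2], [1], [2]])
preds = classify_for_several_k(sqdists, T, [1, 3, 5])
print(preds)
```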
What might you change to speed up the calculations for a single $k$ value? (Hint: ever heard of a kd-tree?)
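SciPy provides a kd-tree in scipy.spatial.KDTree, whose query method returns the distances and indices of the nearest neighbors. This sketch builds a tree on small random two-dimensional data (the names Xtrain_small, Ttrain_small and Xnew are made up) and checks the neighbors against a brute-force search. kd-trees tend to lose their advantage as the number of dimensions grows, so the speedup on 784-pixel images may be modest.

```python
import numpy as np
from scipy.spatial import KDTree

rng = np.random.default_rng(3)
Xtrain_small = rng.normal(size=(200, 2))
Ttrain_small = (Xtrain_small[:, 0] > 0).astype(int).reshape(-1, 1)  # made-up labels 0/1
Xnew = rng.normal(size=(5, 2))

tree = KDTree(Xtrain_small)               # built once, at "training" time
dists, indices = tree.query(Xnew, k=3)    # (5, 3) Euclidean distances and indices, nearest first

# Majority vote among the 3 neighbors of each new sample
preds = np.array([np.argmax(np.bincount(Ttrain_small[row, 0])) for row in indices])

# Brute-force check
sq = np.sum((Xnew[:, np.newaxis, :] - Xtrain_small) ** 2, axis=2)
brute = np.argsort(sq, axis=1)[:, :3]
print(np.array_equal(indices, brute), preds)
```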
How could you calculate class probabilities with KNN?
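One simple answer, sketched here, is to report the fraction of the $k$ nearest neighbors that carry each class label. The function knn_class_probabilities is hypothetical; inside use, its neighbor_labels argument would correspond to self.T[indices, :][:, :, 0].

```python
import numpy as np

def knn_class_probabilities(neighbor_labels, classes):
    """Fraction of each sample's k nearest neighbors having each class label.

    neighbor_labels: (n_new, k) labels of the k nearest neighbors.
    classes: 1-D array of all possible class labels.
    Returns an (n_new, n_classes) array whose rows sum to 1.
    """
    # (n_new, k, 1) == (n_classes,) broadcasts to (n_new, k, n_classes)
    return np.mean(neighbor_labels[:, :, np.newaxis] == classes, axis=1)

neighbor_labels = np.array([[1, 1, 2, 1],
                            [2, 2, 2, 1]])
probs = knn_class_probabilities(neighbor_labels, np.array([1, 2]))
print(probs)
```

With small $k$ these probabilities are coarse (multiples of $1/k$), which is one reason to look at several values of $k$.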
n = 20
X = np.random.multivariate_normal([5, 7], [[0.8, -0.5], [-0.5, 0.8]], n)
X = np.vstack((X,
np.random.multivariate_normal([6, 3], [[0.6, 0.5], [0.5, 0.8]], n)))
T = np.vstack((np.ones((n, 1)), 2 * np.ones((n, 1))))
plt.scatter(X[:, 0], X[:, 1], c=T, s=80);
plt.figure(figsize=(8, 8))
n = 20
X = np.random.multivariate_normal([5, 7], [[0.8, -0.5], [-0.5, 0.8]], n)
X = np.vstack((X,
np.random.multivariate_normal([6, 3], [[0.6, 0.5], [0.5, 0.8]], n)))
T = np.vstack((np.ones((n, 1)), 2 * np.ones((n, 1))))
# Make samples as coordinates of grid points across 2-dimensional data space
m = 100
xs = np.linspace(0, 10, m)
ys = xs
Xs, Ys = np.meshgrid(xs, ys)
samples = np.vstack((Xs.ravel(), Ys.ravel())).T
knn = KNN()
knn.train(X, T)
classes = knn.use(samples, k=1)
plt.contourf(Xs, Ys, classes.reshape(Xs.shape), 1, colors=('blue','red'), alpha=0.2)
plt.scatter(X[:, 0], X[:, 1], s=60, c=T);
Ooo, that's a cool plot. Let's show similar plots for the classifiers we have studied so far, including LDA, QDA, k-NN, and neural nets.
def plot_result(X, Xs, Ys, classes):
    # Note: the points are colored by the T defined in the notebook, not by a parameter
    plt.contourf(Xs, Ys, classes.reshape(Xs.shape), 1, colors=('blue', 'red'), alpha=0.2)
    plt.scatter(X[:, 0], X[:, 1], s=60, c=T);
import neuralnetworksA4 as nn
import qdalda
n = 40
X = np.random.multivariate_normal([5, 6], [[0.9, -0.2], [-0.2, 0.9]], n)
X = np.vstack((X,
np.random.multivariate_normal([6, 3], [[2, 0.4], [0.4, 2]], n)))
T = np.vstack((np.ones((n, 1)), 2 * np.ones((n, 1))))
m = 100
xs = np.linspace(0, 10, m)
ys = xs
Xs,Ys = np.meshgrid(xs, ys)
samples = np.vstack((Xs.ravel(), Ys.ravel())).T
plt.figure(figsize=(20, 30))
# Create and train Quadratic Discriminant Analysis (QDA)
# and Linear Discriminant Analysis (LDA) Classifiers
qda = qdalda.QDA()
qda.train(X, T)
lda = qdalda.LDA()
lda.train(X, T)
# Create and train k-nearest-neighbor (KNN) classifier
knn = KNN()
knn.train(X, T)
ploti = 0
# Use and plot results for LDA and QDA
ploti += 1
plt.subplot(5, 3, ploti)
classes = lda.use(samples)
plot_result(X, Xs, Ys, classes)
plt.title('LDA')
ploti += 1
plt.subplot(5, 3, ploti)
classes = qda.use(samples)
plot_result(X, Xs, Ys, classes)
plt.title('QDA')
ploti += 1  # leave the third panel of the first row empty
# Use and plot results for KNN with various values of k
for k in [1, 2, 3, 5, 10, 20]:
    ploti += 1
    plt.subplot(5, 3, ploti)
    classes = knn.use(samples, k)
    plot_result(X, Xs, Ys, classes)
    plt.title(f'KNN k={k}')
# Use and plot results for neural networks with various hidden layer structures
for n_hiddens in [[], [1], [2], [10], [10, 10], [5, 5, 5, 5]]:
    ploti += 1
    plt.subplot(5, 3, ploti)
    nnet = nn.NeuralNetworkClassifier(2, n_hiddens, 2)
    nnet.train(X, T, X, T, n_epochs=1000, method='scg', verbose=False)
    classes, _ = nnet.use(samples)
    plot_result(X, Xs, Ys, classes)
    plt.title(f'nnet {n_hiddens}')